
### #######################################
### Does variable selection using glmnet
### Output is a formula
### #######################################

### Declare where this script sits relative to the project root, so that the
### here::here() calls below resolve against that root
here::i_am("newR/09_varsel_alr.R")

### Packages used below: here (paths), glmnet (cv.glmnet), tidyverse (ggplot2
### plotting), doMC (registerDoMC), broom (tidy) and parallel (detectCores).
### knitr is loaded but not referenced in the visible part of this script.
library(here)
library(glmnet)
library(tidyverse)
library(doMC)
library(knitr)
library(broom)
library(parallel)

### Fix the random number stream; cv.glmnet() is called below without an
### explicit foldid, so its fold assignment depends on this seed
set.seed(88516)

### Build a model formula from the non-zero coefficients of a glmnet fit.
###
### obj:    a list of single-column coefficient matrices carrying row names
###         (the script passes the result of coef() on the cross-validated
###         fit). Only the first element is inspected.
###         NOTE(review): this assumes every response shares the same set of
###         non-zero rows -- confirm for the family being fitted.
### depvar: name placed on the left-hand side of the formula.
###
### Returns a formula `depvar ~ term1 + term2 + ...`, or `depvar ~ 1` when
### nothing apart from the intercept has a non-zero coefficient.
glmnet_to_formula <- function(obj, depvar = "y") {
    coefs <- obj[[1]]
    terms <- rownames(coefs)[which(coefs != 0)]
    ## Remove the intercept by name rather than by position, so that a real
    ## predictor is not dropped when the intercept itself happens to be zero
    terms <- setdiff(terms, "(Intercept)")
    ## Handle Scotland and Wales: map the dummy-column names back to the
    ## underlying factor so the formula refers to `Nation` itself
    terms <- sub("NationScotland", "Nation", terms)
    terms <- sub("NationWales", "Nation", terms)
    ## Both dummies collapse to the same name, so de-duplicate afterwards
    terms <- unique(terms)
    ## An empty right-hand side ("y ~ ") cannot be parsed by as.formula();
    ## fall back to an intercept-only model
    if (length(terms) == 0) {
        terms <- "1"
    }
    f <- paste0(depvar,
                " ~ ",
                paste0(terms, collapse = " + "))
    as.formula(f)
}

### Register a multicore (doMC) backend, leaving two cores free for the rest
### of the system. detectCores() can return NA, and on machines with two
### cores or fewer the subtraction would give zero or less, so clamp the
### result to at least one worker.
ncores <- max(1, parallel::detectCores() - 2, na.rm = TRUE)
registerDoMC(cores = ncores)

### Run the project's 00_helpers.R script
source(here::here("newR", "00_helpers.R"))

### Load the tidied modelling data
dat <- readRDS(here::here("working", "tidy_model_data.rds"))

### Outcome columns; the first one is the reference category for the
### log-ratio transform below
pvars <- c("Con_BE", "Lab_BE", "Lib_BE", "Nat_BE", "Oth2_BE")

### Values at or below zero cannot be log-transformed, so lift them to a
### small positive floor first
min_share <- 1 / 40000
for (p in pvars) {
    shares <- dat[, p]
    dat[, p] <- replace(shares, shares <= 0, min_share)
}

### Make sure everything sums to 1
row_totals <- rowSums(dat[, pvars])
dat[, pvars] <- dat[, pvars] / row_totals

### Response matrix, one column per outcome
y <- as.matrix(dat[, pvars])

### ALR-transform y: log of every column relative to the first column
### (the first column itself becomes all zeros)
reference <- y[, 1]
y <- log(y / reference)

### Main effects offered to glmnet. The design matrix holds each of these
### plus every pairwise interaction between them (the `^ 2` in the formula).
main_effects <- c(
    "Lab_GE_sc", "Lib_GE_sc", "Nat_GE_sc", "Oth_GE_sc",
    "LabGovt_sc", "LibGovt_sc",
    "LabInc_sc", "LibInc_sc", "RPI_sc", "winter_sc",
    "Turn_at_GE_sc", "Nation",
    "LabPollChg_sc", "LibPollChg_sc", "OthPollChg_sc",
    "ConCand_BE_sc", "LabCand_BE_sc",
    "LibCand_BE_sc", "NatCand_BE_sc", "OthCand_BE_sc",
    "ConCand_GE_sc", "LabCand_GE_sc",
    "LibCand_GE_sc", "NatCand_GE_sc", "OthCand_GE_sc",
    "IncCandidateLib_sc", "IncCandidateOther_sc"
)

### One-sided formula: ~ (a + b + ...) ^ 2
x_formula <- as.formula(
    paste0("~ (", paste(main_effects, collapse = " + "), ") ^ 2")
)

x <- model.matrix(x_formula, data = dat)


### Do the cross-validation
### The first column of y is the ALR reference (log(1) = 0 throughout), so it
### is dropped from the response. parallel = TRUE makes cv.glmnet fit the
### folds through the foreach backend registered with registerDoMC() earlier
### in this script; without it the registered workers are never used.

mfit <- cv.glmnet(x, y[,-1], family = "mgaussian", parallel = TRUE)

### One row per lambda, with the cross-validated error (estimate) and its
### interval (conf.low / conf.high) as used in the plot below
mft <- tidy(mfit)

### Position on the lambda path of the hand-picked model marked (C) in the
### plot. Defined once so the reported variable count and the plotted marker
### cannot drift apart.
lambda_c_idx <- 19

### Number of non-zero terms at (A) lambda.min, (B) lambda.1se and (C) the
### hand-picked lambda. lambda.min / lambda.1se are taken from mfit$lambda
### itself, so the equality lookup is exact.
vars_1 <- mfit$nzero[which(mfit$lambda == mfit$lambda.min)]
vars_2 <- mfit$nzero[which(mfit$lambda == mfit$lambda.1se)]
vars_3 <- mfit$nzero[lambda_c_idx]

### Cross-validation curve with the three candidate lambdas marked.
### p1 is only assigned here; nothing in the visible part of this script
### prints or saves it.
p1 <- ggplot(mft,
             aes(x = log(lambda),
                 y = estimate,
                 ymin = conf.low,
                 ymax = conf.high)) +
    geom_vline(xintercept = log(mfit$lambda.min), linetype = 2) +
    geom_vline(xintercept = log(mfit$lambda.1se), linetype = 2) +
    geom_vline(xintercept = log(mfit$lambda[lambda_c_idx]),
               linetype = 1, colour = "red") +
    annotate(geom = "text",
             x = log(mfit$lambda.min),
             y = .6,
             size = 3,
             label = paste0("(A) \n",vars_1, " \nvariables \nselected "),
             hjust = "right") + 
    annotate(geom = "text",
             x = log(mfit$lambda.1se),
             y = .6,
             size = 3,
             label = paste0(" (B)\n ", vars_2, "\n variables\n selected"),
             hjust = "left") + 
    annotate(geom = "text",
             x = log(mfit$lambda[lambda_c_idx]),
             y = .6,
             size = 3,
             label = paste0("(C) \n", vars_3, " \nvariables \nselected "),
             hjust = "right") + 
    geom_pointrange(colour = "darkgrey", fatten = 0.5) +
    scale_x_continuous("log(λ)") +
    scale_y_log10("MSE") + 
    theme_bw()


### Keep the full lambda path to hand
lambdas <- mfit$lambda

### Turn the coefficients at lambda.min into a model formula.
### NOTE(review): the plot above also highlights lambda[19] in red as (C);
### the saved formula is built from lambda.min (A) -- confirm that this is
### the intended choice.
cobj <- coef(mfit, s = "lambda.min")
varsel_f <- glmnet_to_formula(cobj)

### Store the selected formula for later scripts
out_file <- here::here("working", "varsel_formula_alr.rdata")
save(varsel_f, file = out_file)
